from LinearModule_utils import *
from Attention_utils import *
from LayerNormModule_utils import *
from Activation_utils import * 
import numpy as np
from LightAttentionBackward import *
from scipy.special import softmax
from RMSnorm_utils import *
from GLU import *
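#Tests run on a sequence of config.seq_length tokens preceded by config.num_blanks blank positions that hold the weights of the simulated layer
#(the original note described a toy setting: a 10 length sequence, with 2 subsequences separated by 2 blanks)
#Each token is din dimensional (the original toy setting used 20 dimensions)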




#------------------------------------------------------------------#
#config contains the following important parameters: 
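#config.memory_start : Start index of memorized embeddings (from a previous layer). NOTE: the Config class below does not define this field; the modules in this file receive memory_index as an explicit argument instead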
#config.seq_length : Sequence length of the smaller model that we are trying to simulate
#config.position_dim : Dimension of position embeddings (config.seq_length + config.num_blanks)
#config.num_blanks : Number of blanks to separate the sub-sequences
#config.num_attention_heads : Number of attention heads
#config.scale_embeddings : A scale to initialize different query, key matrices
#config.inner_lr : Inner learning rate to simulate sgd 
#config.gate_scale: Scale to use inside gates  
#------------------------------------------------------------------# 


class Config:
    """Bundle of hyper-parameters passed to every module exercised in this script.

    The ten constructor arguments are stored unchanged under attributes of the
    same name.  The remaining attributes are fixed defaults that callers cannot
    override through the constructor.
    """

    def __init__(
        self,
        hidden_size,
        seq_length,
        position_dim,
        num_blanks,
        num_attention_heads,
        scale_embeddings,
        inner_lr,
        gate_scale,
        max_position_embeddings,
        scale_attn_weights,
    ):
        # Values supplied by the caller.
        supplied = {
            "seq_length": seq_length,
            "position_dim": position_dim,
            "num_blanks": num_blanks,
            "num_attention_heads": num_attention_heads,
            "scale_embeddings": scale_embeddings,
            "inner_lr": inner_lr,
            "gate_scale": gate_scale,
            "hidden_size": hidden_size,
            "max_position_embeddings": max_position_embeddings,
            "scale_attn_weights": scale_attn_weights,
        }
        # Constants shared by all tests in this file.
        fixed = {
            "scale_attn_by_inverse_layer_idx": False,
            "reorder_and_upcast_attn": False,
            "attn_pdrop": 0.0,
            "resid_pdrop": 0.0,
            "activation_function": "gelu",
            "epsilon": 1e-6,
            "initial_scale": np.sqrt(4),
            "backprop_through_attention": False,
            "use_einsum": True,
            "restrict_blanks": False,
        }
        # Assign in the same order as before so vars(config) keeps its layout.
        for attribute, value in {**supplied, **fixed}.items():
            setattr(self, attribute, value)
##### Testing feedforward module ######
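#W in R^{dout x din} (768 x 768 with the settings below)
#We compute W x + b
######## Embedding plan ########
#hidden_size = 4 * din = 3072-dimensional embeddings
#Memory starts from hidden_size - din = 2304
#position_dim is 1280 (seq_length 1024 + 256 blanks)
#num_blanks is 256
#num_attention_heads = 12
#scale_embeddings = 1000.
#inner_lr = 0.01
#gate_scale = 10.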

# Sizes of the simulated linear layer (input and output width) and the
# device used by every test in this script.
din=768
dout=768
projection_dout=dout
device='cpu'
# Shared configuration: embeddings are 4*din wide; the sequence has 1024
# token positions plus 256 blank positions (position_dim = 1024 + 256).
config = Config(hidden_size=din*4, seq_length=1024 , position_dim=(1024+256), num_blanks=256, num_attention_heads=12, scale_embeddings=1000., inner_lr=0.01, gate_scale=10., max_position_embeddings=2048, scale_attn_weights=True)


#weights
# Random weight matrix (dout x din) and bias (dout,) of the layer being simulated.
wts=np.random.normal(size=(dout, din))
#size=(config.num_blanks, (dout // (config.num_blanks) ) * din))
bs=np.random.normal(size=(dout,))
# Index passed to the modules as `memory_index`: the last din coordinates of an embedding.
memory_index=config.hidden_size - din


#projection_matrix = 1. / np.sqrt(projection_dout) *  np.random.normal( size=(projection_dout, dout) ) 
# Flatten the weight matrix into one row per blank position: each row holds
# (projection_dout // num_blanks) consecutive rows of `wts`, concatenated.
weights =  wts.reshape((config.num_blanks,  (projection_dout // config.num_blanks) * din ))
bias =   bs

# Module under test for the forward pass of the linear layer.
forward_ = LinearForward(config=config, \
                         din=din, \
                         dout=projection_dout, \
                         use_softmax=False, \
                         memory_index=memory_index, \
                         projection_matrix=None, \
                        ) 

# Random token inputs, one din-dimensional row per sequence position.
input_=np.random.normal(size=(config.seq_length, din))



#Creating the input_sequence

# Put the simulated linear layer on the target device before running it.
forward_.to(device)

###################### Check Forward Pass #######################

# Assemble the embedding sequence one position at a time.  The first
# config.num_blanks positions are "blank" rows carrying the flattened weights
# (and, on the first blank row, the bias); every other position carries one
# input token in its first din coordinates, with random values elsewhere.
hidden_states = []
position_states = []
counter = 0
blank_counter = 0
blank_positions = np.arange(config.num_blanks)
for i in range(config.position_dim):
    pos = torch.zeros(config.position_dim)
    if i in blank_positions:
        weight_row = weights[blank_counter]
        hid = torch.zeros(config.hidden_size)
        hid[:len(weight_row)] = torch.tensor(weight_row)
        if blank_counter == 0:
            # The bias is stored right after the weights of the first blank row.
            hid[len(weight_row): len(weight_row) + len(bias)] = torch.tensor(bias)
        # Blank rows take the one-hot positions after the token positions.
        pos[config.seq_length + blank_counter] = 1.
        blank_counter = (blank_counter + 1) % (config.num_blanks)
    else:
        hid = torch.randn(config.hidden_size)
        hid[:din] = torch.tensor(input_[counter])
        pos[counter] = 1.
        counter += 1
    hidden_states.append(hid)
    position_states.append(pos)

# Add a leading batch dimension of size one and run the module.
hidden_states = torch.stack(hidden_states, dim=0).unsqueeze(dim=0).to(device)
position_states = torch.stack(position_states, dim=0).unsqueeze(dim=0).to(device)
forward_output = forward_.forward(hidden_states, position_states)
output = forward_output.detach().cpu().numpy()

# Reference result computed directly with numpy: W x + b for every token.
true_output = input_ @ wts.T + np.expand_dims(bs, axis=0)

# Compare the module's output with the reference on every token position.
counter = 0
max_difference = 0.
for i in range(config.position_dim):
    if i in blank_positions:
        continue
    row_error = np.amax(np.absolute(true_output[counter] - output[0, i, :dout]))
    max_difference = max(max_difference, row_error)
    counter += 1
print ("Linear forward: Max absolute difference between coordinates:", max_difference)
##################### Backward pass #############################
#random_back_projection=1. / np.sqrt(projection_dout) *  np.random.normal( size=(din//2, din) ) 
# Module under test: combined backward pass and gradient-descent update of the
# simulated linear layer.
backward_ = Linear_Descent_Backward(config=config, \
                                    din=din, \
                                    dout=projection_dout, \
                                    use_softmax=False, \
                                    projection_matrix=None, \
                                    memory_index=memory_index,
                                   )
#descent_  = LinearDescent(config=config, din=din, dout=dout, use_softmax=False)
backward_.to(device)
#descent_.to(device)
#hidden_states = []
#position_states = []

#I push in \nabla y manually

    
# Reference accumulators: weight gradient, bias gradient, per-token input
# gradients, and the output gradients that contributed to them.
counter=0
blank_counter=0
nablaw = np.zeros((projection_dout, din))
nablab = np.zeros(projection_dout)
nablax = []
nablays = []
#wts = np.reshape(weights, (dout, din))
for i in range(config.position_dim):
    if i not in blank_positions:
        #hid = torch.zeros(config.hidden_size)
        #hid[memory_index: memory_index+din] = torch.tensor( input_[counter] )
        # A random output gradient is drawn for every token position ...
        nablay = np.random.normal(size=(projection_dout,))
        # ... but only positions in the first half of the tokens feed the
        # reference gradients.
        if i < config.num_blanks + config.seq_length // 2:
            nablays += [nablay]
            #print ((np.expand_dims( nablay, axis=-1 ) @ np.expand_dims( input_[counter], axis=0 )).shape)
            # Outer product nablay x^T is this token's contribution to dL/dW.
            nablaw += np.expand_dims( nablay, axis=-1 ) @ np.expand_dims( input_[counter], axis=0 )
            nablab += nablay
            # Gradient with respect to the input: W^T nablay.
            nablax += [ nablay  @ wts ]
        # Overwrite the forward output in place so its first projection_dout
        # coordinates now carry the output gradient.
        forward_output[0, i, :projection_dout] = torch.tensor( nablay, dtype=forward_output.dtype ).to(forward_output.device)
        # With projection_dout == dout (as set above) this slice is empty.
        forward_output[0, i, projection_dout:dout] = 0. 
        #hidden_states += [ hid ]
        #pos = torch.zeros( config.position_dim )
        #pos [counter] = 1. 
        #position_states += [ pos ]
        
        counter += 1
    #else:
        #hid = torch.zeros(config.hidden_size)
        #hid[:len(weights [blank_counter])] = torch.tensor( weights [blank_counter] )
        #if blank_counter == 0:
        #    hid[len(weights [blank_counter]): len(weights [blank_counter])+len(bias)] = torch.tensor( bias )
            
        #hidden_states += [ hid ]
        #pos = torch.zeros( config.position_dim )
        #pos [ config.seq_length + blank_counter ] = 1.
        #position_states += [ pos ]
        
        #blank_counter = (blank_counter + 1) % (config.num_blanks)

        

# Attention mask: lower-triangular (causal) everywhere, then fully open inside
# the leading block of `right` positions (blanks plus first half of the tokens).
attention_mask = torch.ones( ( config.position_dim,  config.position_dim ) ).to(device)
attention_mask = torch.tril(attention_mask).view( (1, 1, config.position_dim,  config.position_dim)  )
right=config.num_blanks + config.seq_length // 2
for i in range(0, right ):
    attention_mask[:, :, i, i: right] = 1.

# Run the module on the gradient-carrying sequence.
descent_update  = backward_.forward(forward_output, position_states, attention_mask).detach().cpu().numpy() 
#LinearDescent(config=config, din=din, dout=projection_dout, use_softmax=False, memory_index=memory_index)
#descent_.to(device)
#descent_update = descent_.forward(back_output_, position_states, attention_mask).detach().cpu().numpy() 


# Check 1: the first din coordinates of each first-half token row should equal
# the reference input gradient.
counter=0
max_difference=0.
for i in range(config.position_dim):
    if i not in blank_positions:
        if i < (config.num_blanks + config.seq_length // 2):
            diff = np.absolute(descent_update[0, i, : din] - nablax[counter])
            #print ( np.amax(diff))
            
            max_difference=max(max_difference, np.amax(diff) )
            counter += 1
print ("Linear backward: Max absolute difference between coordinates:", max_difference)        




# Reference parameters after one SGD step with learning rate config.inner_lr.
updated_wts = wts - config.inner_lr * nablaw
updated_b   = bias - config.inner_lr * nablab

# Check 2: each blank row should now hold its slice of the updated weights.
# The bias comparison below is disabled (commented out) in this script.
counter=0
max_difference=0.
num_blanks=config.num_blanks
for i in range(num_blanks):
    pos = blank_positions[i]
    result=descent_update[0, pos, :din * projection_dout // num_blanks]
    max_difference=max(max_difference, np.amax(np.absolute(np.reshape(result, (-1, din)) - updated_wts [i * (projection_dout // num_blanks): (i+1) * (projection_dout // num_blanks) ] )) )
    
    #if i == 0:
    #    result=descent_update[0, pos, din * projection_dout // num_blanks:din * projection_dout // num_blanks + projection_dout]
    #    max_difference=max( max_difference, np.amax(np.absolute(result - updated_b)) )
    
print ("Linear descent: Max absolute difference between coordinates:", max_difference)        


##### Testing attention module ######
#[Q, K, V] stacked in R^{3*din x din}
#We compute \sum_j a_ij V x_j for every position i, where a_ij is the attention score between i and j
# The attention test reuses `din`, `config`, `device` and `blank_positions`
# from the linear test above.  The module is told that its memory region is
# the last 3*din coordinates of each embedding.
memory_index=config.hidden_size-3*din
input_=np.random.normal(size=(config.seq_length, din))

hidden_states = []
position_states = []
num_attnt_heads=12

# Stacked [Q; K; V] projection (3*din x din) with its bias (3*din,), split
# into three equal parts.
QKV=np.random.normal(size=(3*din, din))
QKV_bs=np.random.normal(size=(3*din,))

Q_wts, K_wts, V_wts = np.split(QKV, 3)
Q_bias,  K_bias,  V_bias  = np.split(QKV_bs, 3)

# Flatten every din x din matrix into one row per blank position; each row
# holds (din // num_blanks) consecutive matrix rows.
Q_weights=np.reshape(Q_wts, (config.num_blanks, ( din // config.num_blanks ) * din))
K_weights=np.reshape(K_wts, (config.num_blanks, ( din // config.num_blanks ) * din))
V_weights=np.reshape(V_wts, (config.num_blanks, ( din // config.num_blanks ) * din))


##### Test on forward pass ######
attention_forward = AttentionForward(config=config, \
                                     din=din, \
                                     num_attnt_heads=num_attnt_heads, \
                                     use_softmax=False, \
                                     separate_QK=False, \
                                     memory_index=memory_index, \
                                     projection_matrix=None, \
                                    )
attention_forward.to(device)


counter=0
blank_counter=0

# Value weights are passed in a separate tensor: one row per blank position,
# with the value bias stored right after the weights of the first row.
value_wt_tensor=torch.zeros((1, config.num_blanks, config.hidden_size))
value_wt_tensor[:, :config.num_blanks, : len(V_weights[0])] = torch.tensor(V_weights)
value_wt_tensor[:, 0, len(V_weights[0]): len(V_weights[0]) + len(V_bias)] = torch.tensor(V_bias)

# Key weights use the same layout.  The bias slice is bounded with the key
# tensors' own lengths (the original borrowed the value tensors' lengths,
# which only worked because the shapes coincide).
key_wt_tensor=torch.zeros((1, config.num_blanks, config.hidden_size))
key_wt_tensor[:, :config.num_blanks, : len(K_weights[0])] = torch.tensor(K_weights)
key_wt_tensor[:, 0, len(K_weights[0]): len(K_weights[0]) + len(K_bias)] = torch.tensor(K_bias)


# Build the embedding sequence: blank rows carry the query weights (and the
# query bias on the first blank row); token rows carry one input in their
# first din coordinates and random values elsewhere.
for i in range(config.position_dim):
    if i not in blank_positions:
        hid = torch.randn(config.hidden_size)
        hid[:din] = torch.tensor( input_[counter] )

        hidden_states += [ hid ]
        pos = torch.zeros( config.position_dim )
        pos [counter] = 1.
        position_states += [ pos ]

        counter += 1
    else:
        hid = torch.zeros(config.hidden_size)
        hid[:len(Q_weights [blank_counter])] = torch.tensor( Q_weights [blank_counter] )
        if blank_counter == 0:
            hid[len(Q_weights [blank_counter]): len(Q_weights [blank_counter])+len(Q_bias)] = torch.tensor( Q_bias )

        hidden_states += [ hid ]
        pos = torch.zeros( config.position_dim )
        pos [ config.seq_length + blank_counter ] = 1.
        position_states += [ pos ]

        blank_counter = (blank_counter + 1) % (config.num_blanks)

# Move every input tensor to the target device.  The original moved the value
# tensor but not the key tensor.
key_wt_tensor = key_wt_tensor.to(device)
value_wt_tensor = value_wt_tensor.to(device)
hidden_states = torch.stack(hidden_states, dim=0).unsqueeze(dim=0).to(device)
position_states = torch.stack(position_states, dim=0).unsqueeze(dim=0).to(device)
output_ = attention_forward.forward(hidden_states, position_states, key_wt_tensor, value_wt_tensor)
output = output_.detach().cpu().numpy()


# ---- Reference computation with numpy ----
Q=input_ @ Q_wts.T + np.expand_dims(Q_bias, axis=0)
K=input_ @ K_wts.T + np.expand_dims(K_bias, axis=0)
V=input_ @ V_wts.T + np.expand_dims(V_bias, axis=0)

# Split into heads: Q, V and in_ become (heads, seq, head_dim); K is further
# transposed to (heads, head_dim, seq) so that Q @ K gives the score matrix.
Q = np.swapaxes(np.reshape(Q, (-1, num_attnt_heads, din // num_attnt_heads )), 0, 1)
K = np.swapaxes(np.swapaxes(np.reshape(K, (-1, num_attnt_heads, din // num_attnt_heads )), 0, 1), 1, 2)
V = np.swapaxes(np.reshape(V, (-1, num_attnt_heads, din // num_attnt_heads )), 0, 1)
in_ = np.swapaxes(np.reshape(input_, (-1, num_attnt_heads, din // num_attnt_heads )), 0, 1)

# mask == 1 marks the strictly upper-triangular (future) entries.
mask = np.ones((num_attnt_heads, len(input_), len(input_)))
mask = 1. - np.tril(mask)

attn_scores = Q @ K
if config.scale_attn_weights:
    attn_scores = config.initial_scale * attn_scores / np.sqrt(config.hidden_size // config.num_attention_heads)
# Future positions get the most negative float so softmax assigns them zero weight.
attn_scores [np.where(mask == 1.)] = np.finfo(np.float64).min

attn_scores = softmax(attn_scores, axis=-1)

# Merge the heads back: attention-weighted values and attention-weighted raw
# inputs, both of shape (seq, din).  `true_in` is used by the descent test.
true_output = np.reshape(np.swapaxes(attn_scores @ V, 0, 1), (-1, din))
true_in = np.reshape(np.swapaxes(attn_scores @ in_, 0, 1), (-1, din))

# Compare on the first half of the token positions only.
max_differnce=0.
counter=0
for i in range(config.num_blanks + config.seq_length // 2):
    if i not in blank_positions:
        diff=np.absolute(true_output[counter] - output[0, i, :din])
        max_differnce = max(max_differnce, np.amax(diff) )
        counter += 1

print ("Attention forward: Max absolute difference between coordinates:", max_differnce)
    

##### Test on backward pass ######
#attn_scores=np.reshape(attn_scores, (-))
#attn_scores = np.reshape( np.swapaxes(attn_scores, 0, 1), (-1, num_attnt_heads*len(input_)) )
counter = 0
blank_counter = 0
# Module under test: backward pass plus descent update for the value weights.
attention_descent = LightAttentionBackward_Descent (config, din, num_attnt_heads, use_softmax=False, memory_index=memory_index, retain_nablay=False)
attention_descent.to(device)
nablays = []

# Overwrite the first din coordinates of every token row of the forward output
# with a freshly drawn output gradient.
for i in range(config.position_dim):
    if i not in blank_positions:
        nablay = np.random.normal(size=(din,))
        nablays += [nablay]
        output_[0, i, :din] = torch.tensor( nablay, dtype=output_.dtype ).to(output_.device)


# Re-create the value weight rows and write them into the blank positions.
value_wt_tensor=torch.zeros((1, config.num_blanks, config.hidden_size))
value_wt_tensor[:, :config.num_blanks, : len(V_weights[0])] = torch.tensor(V_weights)
value_wt_tensor[:, 0, len(V_weights[0]): len(V_weights[0]) + len(V_bias)] = torch.tensor(V_bias)

output_[:, :config.num_blanks] = value_wt_tensor.to(output_.device)

# Causal mask that is fully open inside the leading block of `right`
# positions.  A single slice assignment gives the same mask as setting
# [i, i:right] row by row on top of the lower triangle.
attention_mask = torch.ones( ( config.position_dim,  config.position_dim ) ).to(device)
num_blanks=config.num_blanks
attention_mask = torch.tril(attention_mask).view( ( 1, 1, config.position_dim,  config.position_dim)  )
right= config.num_blanks + config.seq_length // 2
attention_mask[:, :, :right, :right] = 1.


# Reference gradient reaching each value vector:
#   grad_value[j, head block] = sum_{k < seq_length // 2} attn_scores[head, k, j] * nablays[k][head block]
# computed as one batched matrix product per head instead of a triple Python
# loop over (j, head, k).  This assumes din is divisible by num_attnt_heads,
# as the Q/K/V reshapes of the forward test already do.
nablays = np.asarray(nablays)
half = config.seq_length // 2
head_dim = din // num_attnt_heads
# (heads, half, head_dim): per-head slices of the first-half output gradients.
nablay_heads = np.swapaxes(np.reshape(nablays[:half], (-1, num_attnt_heads, head_dim)), 0, 1)
# (heads, seq, half) @ (heads, half, head_dim) -> (heads, seq, head_dim)
grad_value = np.swapaxes(attn_scores[:, :half, :], 1, 2) @ nablay_heads
grad_value = np.reshape(np.swapaxes(grad_value, 0, 1), (len(input_), din))


# NOTE: the stand-alone attention backward check is disabled in this script;
# only the descent update is verified below.
descent_update = attention_descent.forward(output_, position_states, attention_mask)
nabla_value    = grad_value

# Reference value parameters after one SGD step, restricted to the first half
# of the tokens.
updated_V_wts = V_wts - config.inner_lr * nabla_value[:config.seq_length // 2].T @ input_[:config.seq_length // 2]
updated_b     = V_bias - config.inner_lr * np.sum(nabla_value[:config.seq_length // 2], axis=0)

# Alternative update built from the raw output gradients and the
# attention-averaged inputs; the module's result is compared against this one.
updated_V_wts_prime = V_wts - config.inner_lr * nablays[:config.seq_length // 2].T @ true_in[:config.seq_length // 2]


counter=0
max_difference=0.
num_blanks=config.position_dim - config.seq_length
descent_update=descent_update.detach().cpu().numpy()
dout=3*din
# Each blank row should now hold its slice of the updated value weights.
# The bias comparison is disabled in this script.
for i in range(num_blanks):
    pos = blank_positions[i]
    result=descent_update[0, pos, :din * din // num_blanks]
    difference=np.absolute(np.reshape(result, (-1, din)) - updated_V_wts [i * (din // num_blanks): (i+1) * (din // num_blanks) ] )
    max_difference=max(max_difference, np.amax(difference) )

print ("Attetnion descent: Max absolute difference between coordinates:", max_difference)


####################### Test Layernorm ####################
#din=20
#config = Config(hidden_size=1200, seq_length=10, position_dim=14, num_blanks=4, num_attention_heads=20, scale_embeddings=1000., inner_lr=1e-3, gate_scale=10., max_position_embeddings=32, scale_attn_weights=False)

# Layernorm test: the modules are told that memory is the last 2*din
# coordinates of each embedding.
memory_index=config.hidden_size-2*din
#blank_positions = [0, 1, 2]
hidden_states = []
position_states = []

# The layernorm scale is a random diagonal din x din matrix, flattened into
# one row per blank position; `wts` is the same matrix in square form.
weights=np.diag(np.random.normal(size=(din,))).reshape((config.num_blanks, (din // config.num_blanks) * din))
wts=weights.reshape((din, din))
bias=np.random.normal(size=(din,))

##### Test on forward pass ######
layernorm_forward = LayerNormForward(config=config, \
                                     din=din, \
                                     use_softmax=False, \
                                     memory_index=memory_index, \
                                    )
layernorm_forward.to(device)

# Build the embedding sequence: blank rows carry the flattened scale matrix
# (plus the bias on the first blank row); token rows carry one row of
# `input_` (still the array drawn for the attention test) in their first din
# coordinates and random values elsewhere.
counter=0
blank_counter=0
for i in range(config.position_dim):
    if i not in blank_positions:
        hid = torch.randn(config.hidden_size)
        hid[:din] = torch.tensor( input_[counter] )
                
        hidden_states += [ hid ]
        pos = torch.zeros( config.position_dim )
        pos [counter] = 1. 
        position_states += [ pos ]
        
        counter += 1
    else:
        hid = torch.zeros(config.hidden_size)
        hid[:len(weights [blank_counter])] = torch.tensor( weights[blank_counter] )
        
        if blank_counter == 0:
            hid[len(weights [blank_counter]): len(weights [blank_counter])+len(bias)] = torch.tensor( bias )
           
        hidden_states += [ hid ]
        pos = torch.zeros( config.position_dim )
        pos [ config.seq_length + blank_counter ] = 1.
        position_states += [ pos ]
        
        blank_counter = (blank_counter + 1) % config.num_blanks    
    
# Add a batch dimension of size one and run the module.
hidden_states = torch.stack(hidden_states, dim=0).unsqueeze(dim=0).to(device)
position_states = torch.stack(position_states, dim=0).unsqueeze(dim=0).to(device)   
layernorm_output = layernorm_forward.forward(hidden_states, position_states)
output=layernorm_output.detach().cpu().numpy()     



counter=0
max_difference=0.

# Reference layernorm for the first half of the token positions.
# `ys` collects the normalized inputs before scale and bias are applied.
ys = []
for i in range(config.position_dim):
    if i not in blank_positions:
        if i < config.seq_length // 2 + config.num_blanks:

            prediction = input_[counter]
            mean = np.mean(prediction)
            # Despite its name, `var` holds the standard deviation (square
            # root of the mean squared deviation); no epsilon is added here.
            var  = np.mean((input_[counter] - mean) ** 2) ** 0.5
            ys += [(prediction - mean) / var]
            
            # Scale by the diagonal matrix and shift by the bias.
            prediction = wts @ (prediction - mean) / var + bias
            max_difference=max( max_difference, np.amax(np.absolute(prediction - output[0, i, :din])) )
            counter += 1
print ("Layernorm Forward: Max absolute difference between coordinates:", max_difference)        


##### Test on backward pass ######
# Module under test: backward pass of the simulated layernorm.
layernorm_backward = LayerNormBackward(config=config, \
                                       din=din, \
                                       use_softmax=False, \
                                       memory_index=memory_index, \
                                       retain_nablay=False, \
                                      )
#layernorm_descent  = LayerNormDescent(config=config, din=din, use_softmax=False, memory_index=memory_index)

layernorm_backward.to(device)
#layernorm_descent.to(device)


# Reference quantities: gradients with respect to the normalized input
# (true_nablay) and the raw input (true_nablax), the drawn output gradients
# (nablazs), and the scale/bias gradients (nablaw, nablab).
counter=0
blank_counter=0
true_nablay = []
true_nablax = []
nablazs = []
nablaw = np.zeros((din, din))
nablab = np.zeros(din)
for i in range(config.position_dim):
    if i not in blank_positions:
        # Draw an output gradient and write it in place into the first din
        # coordinates of the forward output at this position.
        nablaz = np.random.normal(size=(din,))
        nablazs += [nablaz]
        layernorm_output[ :, i, :din ] = torch.tensor( nablaz ).to(device)    
            
        prediction = input_[counter]
        #hidden_states[ :, i, config.memory_start + din: config.memory_start + 2*din] = torch.tensor( prediction ).to(device)
        # Gradient with respect to the normalized input: nablaz multiplied by
        # the scale matrix.
        nablay = nablaz @ wts
        
        true_nablay += [ nablay ]
        
        
        mean = np.mean(prediction)
        # As in the forward test, `var` is the standard deviation.
        var  = np.mean((input_[counter] - mean) ** 2) ** 0.5
        prediction = (prediction - mean) / var
        # Only the first half of the tokens contributes to the parameter gradients.
        if i < config.num_blanks + config.seq_length // 2:
            nablaw += np.expand_dims( nablaz, axis=-1 ) @ np.expand_dims( prediction, axis = 0 )
            nablab += nablaz   
        # Gradient through the normalization: remove the mean of nablay and
        # its component along the normalized input, then divide by the
        # standard deviation.
        true_nablax += [ 1./var * ( nablay - np.mean(nablay) - 1./din * np.dot(nablay, prediction) * prediction ) ]
        
        #hidden_states[ :, i, config.memory_start: config.memory_start + din] = torch.tensor( prediction ).to(device)
        counter += 1
        
# `attention_mask` is the mask built earlier in the script (attention descent test).
output_ = layernorm_backward.forward(layernorm_output, position_states, attention_mask)   
#descent_update = layernorm_descent.forward(output_, position_states, attention_mask)   
#descent_update=descent_update.detach().cpu().numpy()

 
output = output_.detach().cpu().numpy()

# Compare the module's input gradient with the reference on the first half of
# the token positions.
counter=0
max_difference=0.
for i in range(config.position_dim):
    if i not in blank_positions:
        if i < config.num_blanks + config.seq_length // 2:
            prediction = true_nablax[counter] 
            max_difference=max( max_difference, np.amax(np.absolute(prediction - output[0, i, :din])) )
            counter += 1
print ("Layernorm Backward: Max absolute difference between coordinates:", max_difference)      


#counter=0
#max_difference=0.
#for i in range(config.position_dim):
#    if i not in blank_positions:
#        if i < config.num_blanks + config.seq_length // 2:
#            prediction = true_nablax[counter] 
#            max_difference=max( max_difference, np.amax(np.absolute(prediction - descent_update[0, i, : din])) )
#            counter += 1
#print ("Layernorm Backward: Max absolute difference between coordinates:", max_difference)      



#updated_wts = wts - config.inner_lr * nablaw
#updated_b   = bias - config.inner_lr * nablab

#updated_wts = np.diag(updated_wts)
#.reshape((config.position_dim - config.seq_length, (din // (config.position_dim - config.seq_length) ) * din))

#counter=0
#max_difference=0.
#num_blanks=config.position_dim - config.seq_length
#dout=din
#for i in range(num_blanks):
#    pos = blank_positions[i]
#    result=descent_update[0, pos, :din * dout // num_blanks]
    #print (np.reshape(result, (-1, din)), updated_wts [i * (dout // num_blanks): (i+1) * (dout // num_blanks) ] )
#    difference=np.absolute(np.reshape(result, (-1, din)) - updated_wts [i * (dout // num_blanks): (i+1) * (dout // num_blanks) ] )
#    max_difference=max(max_difference, np.amax(difference) )
    
#    if i == 0:
#        result=descent_update[0, pos, din * dout // num_blanks:din * dout // num_blanks + dout]
#        max_difference=max( max_difference, np.amax(np.absolute(result - updated_b)) )
    
#print ("Layernorm Descent: Max absolute difference between coordinates:", max_difference)        

###### Activation test #######
# Activation test: the modules are told that memory is the last din
# coordinates of each embedding.
memory_index=config.hidden_size-din

dout=din
input_=np.random.normal(size=(config.seq_length, din))

blank_positions = np.arange(config.num_blanks)
device='cpu'

activation_forward = ActivationForward (config=config, din=din, memory_index=memory_index, projection_matrix=None)
activation_forward.to(device)


##### Test on forward pass ######
# NOTE: the original script built the input sequence twice and threw the first
# copy away (it also re-sent the unrelated linear module to `device`).  Only
# the copy that is actually fed to the module is built here.
counter=0
blank_counter=0
hidden_states = []
position_states = []

# Token rows carry one input in their first din coordinates and zeros
# elsewhere; blank rows are all zeros because the activation has no weights.
for i in range(config.position_dim):
    if i not in blank_positions:
        hid = torch.zeros(config.hidden_size)
        hid[:din] = torch.tensor( input_[counter] )

        hidden_states += [ hid ]
        pos = torch.zeros( config.position_dim )
        pos [counter] = 1.
        position_states += [ pos ]
        counter += 1
    else:
        hid = torch.zeros(config.hidden_size)
        hidden_states += [ hid ]
        pos = torch.zeros( config.position_dim )
        pos [ config.seq_length + blank_counter ] = 1.
        position_states += [ pos ]

        blank_counter = (blank_counter + 1) % config.num_blanks

# Add a batch dimension of size one and run the module.
hidden_states = torch.stack(hidden_states, dim=0).unsqueeze(dim=0).to(device)
position_states = torch.stack(position_states, dim=0).unsqueeze(dim=0).to(device)
activation_output = activation_forward.forward(hidden_states, position_states)

counter=0
max_difference=0.
num_blanks=config.num_blanks
dout=din
gelu=torch.nn.GELU()
# Finite-difference estimate of the GELU derivative at each checked position;
# reused by the backward test that follows.
gelu_grad = []

for i in range(config.position_dim):
    if i not in blank_positions:
        if i < config.seq_length // 2 + config.num_blanks:
            dout = din
            gelu_out =  gelu(hidden_states[0, i, : din])
            # Forward difference with step 1 / config.scale_embeddings.
            gelu_grad += [ config.scale_embeddings * ( gelu(hidden_states[0, i, : din] + 1./config.scale_embeddings) - gelu(hidden_states[0, i, : din]) ) ]

            max_difference=max( max_difference, torch.max(torch.absolute(gelu_out - activation_output[0, i, : dout])).item() )
            counter += 1
print ("Activation forward: Max absolute difference between coordinates:", max_difference)
#exit(0)

##### Test on backward pass ######

# Random projection handed to the backward module as `input_projection`,
# scaled by 1/sqrt(din).
projection_scale = 1. / np.sqrt(din)
backward_projection = projection_scale * np.random.normal(size=(din, dout))
print (din, dout, backward_projection.shape)
activation_backward = ActivationBackward(
    config=config,
    din=din,
    memory_index=memory_index,
    input_projection=backward_projection,
    projection_matrix=None,
)
activation_backward.to(device)


# Replace the first dout coordinates of every token row of the forward output
# with a freshly drawn output gradient.
counter = 0
nablays = []
for i in range(config.position_dim):
    if i in blank_positions:
        continue
    nablay = np.random.normal(size=(dout,))
    nablays.append(nablay)
    activation_output[:, i, :dout] = torch.tensor(nablay).to(device)
    counter += 1


# `attention_mask` is the mask built earlier in the script.
backward = activation_backward.forward(activation_output, position_states, attention_mask)


counter = 0
max_difference = 0.
num_blanks = config.num_blanks
gelu = torch.nn.GELU()

# Reference: finite-difference GELU derivative times the output gradient,
# checked on the first half of the token positions only.
all_ = []
checked_limit = config.seq_length // 2 + config.num_blanks
for i in range(config.position_dim):
    if i in blank_positions or not (i < checked_limit):
        continue
    gradient = gelu_grad[counter] * activation_output[0, i, : dout]
    projected_gradient = gradient
    all_.append(projected_gradient.cpu().detach().numpy())
    diff = torch.max(torch.absolute(projected_gradient - backward[0, i, : dout])).item()
    max_difference = max(max_difference, diff)
    counter += 1
print ("Activation backward: Max absolute difference between coordinates:", max_difference)




## Check RMS layernorm
# RMS layernorm test: memory is again the last 2*din coordinates of each embedding.
memory_index=config.hidden_size-2*din
#blank_positions = [0, 1, 2]
hidden_states = []
position_states = []

# Random diagonal scale matrix, flattened into one row per blank position;
# `wts` is the same matrix in square (din x din) form.
weights=np.diag(np.random.normal(size=(din,))).reshape((config.num_blanks, (din // config.num_blanks) * din))
wts=weights.reshape((din, din))
bias=np.random.normal(size=(din,))

##### Test on forward pass ######
# NOTE: this rebinds `layernorm_forward`, replacing the LayerNormForward
# instance created by the earlier layernorm test.
layernorm_forward = RMSLayerNormForward(config=config, \
                                     din=din, \
                                     use_softmax=False, \
                                     memory_index=memory_index, \
                                    )
layernorm_forward.to(device)

# Build the embedding sequence: blank rows carry the flattened scale matrix
# (plus the bias on the first blank row); token rows carry one row of
# `input_` (the array drawn for the activation test) in their first din
# coordinates and random values elsewhere.
counter=0
blank_counter=0
for i in range(config.position_dim):
    if i not in blank_positions:
        hid = torch.randn(config.hidden_size)
        hid[:din] = torch.tensor( input_[counter] )
                
        hidden_states += [ hid ]
        pos = torch.zeros( config.position_dim )
        pos [counter] = 1. 
        position_states += [ pos ]
        
        counter += 1
    else:
        hid = torch.zeros(config.hidden_size)
        hid[:len(weights [blank_counter])] = torch.tensor( weights[blank_counter] )
        
        if blank_counter == 0:
            hid[len(weights [blank_counter]): len(weights [blank_counter])+len(bias)] = torch.tensor( bias )
           
        hidden_states += [ hid ]
        pos = torch.zeros( config.position_dim )
        pos [ config.seq_length + blank_counter ] = 1.
        position_states += [ pos ]
        
        blank_counter = (blank_counter + 1) % config.num_blanks    
    
# Add a batch dimension of size one and run the module.
hidden_states = torch.stack(hidden_states, dim=0).unsqueeze(dim=0).to(device)
position_states = torch.stack(position_states, dim=0).unsqueeze(dim=0).to(device)   
layernorm_output = layernorm_forward.forward(hidden_states, position_states)
output=layernorm_output.detach().cpu().numpy()     



# Reset the accumulators used by the comparison loop that follows.
counter=0
max_difference=0.

ys = []
for i in range(config.position_dim):
    if i not in blank_positions:
        if i < config.seq_length // 2 + config.num_blanks:

            prediction = input_[counter]
            mean = np.mean(prediction)
            var  = np.mean((input_[counter] ) ** 2) ** 0.5
            ys += [(prediction) / var]
            
            prediction = wts @ (prediction) / var + bias
            max_difference=max( max_difference, np.amax(np.absolute(prediction - output[0, i, :din])) )
            counter += 1
print ("RMS Layernorm Forward: Max absolute difference between coordinates:", max_difference)        


##### Test on backward pass ######
# RMS layernorm backward check: a random output gradient `nablaz` is written
# into the first `din` coordinates of every non-blank position of
# `layernorm_output`; the module's input gradient is then compared with a
# numpy reference for the first-half positions.
layernorm_backward = RMSLayerNormBackward(
    config=config,
    din=din,
    use_softmax=False,
    memory_index=memory_index,
    retain_nablay=False,
)
layernorm_backward.to(device)

counter = 0
blank_counter = 0
true_nablay = []
true_nablax = []
nablazs = []
nablaw = np.zeros((din, din))
nablab = np.zeros(din)
for i in range(config.position_dim):
    if i in blank_positions:
        continue

    nablaz = np.random.normal(size=(din,))
    nablazs.append(nablaz)
    layernorm_output[:, i, :din] = torch.tensor(nablaz).to(device)

    prediction = input_[counter]
    # Gradient w.r.t. the normalised input (before the scale matrix).
    nablay = nablaz @ wts
    true_nablay.append(nablay)

    mean = np.mean(prediction)  # not used by the RMS reference below
    var = np.mean(input_[counter] ** 2) ** 0.5  # root-mean-square of the input
    prediction = prediction / var

    # Weight / bias gradients are accumulated over the first-half positions only.
    if i < config.num_blanks + config.seq_length // 2:
        nablaw += np.expand_dims(nablaz, axis=-1) @ np.expand_dims(prediction, axis=0)
        nablab += nablaz

    # Reference gradient w.r.t. the un-normalised input.
    true_nablax.append(1./var * (nablay - 1./din * np.dot(nablay, prediction) * prediction))
    counter += 1

output_ = layernorm_backward.forward(layernorm_output, position_states, attention_mask)
output = output_.detach().cpu().numpy()

counter = 0
max_difference = 0.
for i in range(config.position_dim):
    if i in blank_positions or i >= config.num_blanks + config.seq_length // 2:
        continue
    prediction = true_nablax[counter]
    max_difference = max(max_difference, np.amax(np.absolute(prediction - output[0, i, :din])))
    counter += 1
print ("RMS Layernorm Backward: Max absolute difference between coordinates:", max_difference)


##### Test Gates Units
# GLU forward check.
# Reference: (x @ wts.T + bs) * act(x @ activation_wts.T + activation_bs),
# where `act` is the activation taken from the module under test.
# Linear weights/bias are placed on the blank positions of `hidden_states`;
# gate weights/bias are passed separately through `activation_wt_tensor`,
# using the same layout (flattened weight chunk per blank, bias appended on
# the first blank only).
memory_index = config.hidden_size-3*din
forward_ = GLUForward(config, din, use_softmax=False, memory_index=memory_index)
# Keep the module on the same device as its inputs, like the other modules
# under test in this script.
forward_.to(device)
dout = din

gelu_act=forward_.activation.mlp_module.act

# Linear branch parameters.
wts= 1./np.sqrt(1.) * np.random.normal(size=(dout, din))
bs= 1./np.sqrt(1.) * np.random.normal(size=(dout,))

# Gate branch parameters.
activation_wts = 1./np.sqrt(1.) * np.random.normal(size=(dout, din))
activation_bs  = 1./np.sqrt(1.) *  np.random.normal(size=(dout, ))

# One flattened chunk of (dout // num_blanks) rows per blank position.
weights = wts.reshape((config.num_blanks,  (dout // config.num_blanks) * din ))
bias = bs

activation_weights = activation_wts.reshape((config.num_blanks,  (dout // config.num_blanks) * din ))
activation_bias = activation_bs


input_= 1./np.sqrt(1.) * np.random.normal(size=(config.seq_length, din))


counter=0
blank_counter=0
blank_positions = np.arange(config.num_blanks)
hidden_states = []
position_states = []
for i in range(config.position_dim):
    if i not in blank_positions:
        # Data token: random filler except for the input in the first `din` entries.
        hid = torch.randn(config.hidden_size)
        hid[:din] = torch.tensor( input_[counter] )

        hidden_states += [ hid ]
        pos = torch.zeros( config.position_dim )
        pos [counter] = 1.
        position_states += [ pos ]

        counter += 1
    else:
        # Blank token: linear weight chunk, plus the bias on the first blank.
        hid = torch.zeros(config.hidden_size)
        hid[:len(weights [blank_counter])] = torch.tensor( weights [blank_counter] )
        if blank_counter == 0:
            hid[len(weights [blank_counter]): len(weights [blank_counter])+len(bias)] = torch.tensor( bias )

        hidden_states += [ hid ]
        pos = torch.zeros( config.position_dim )
        pos [ config.seq_length + blank_counter ] = 1.
        position_states += [ pos ]

        blank_counter = (blank_counter + 1) % (config.num_blanks)

hidden_states = torch.stack(hidden_states, dim=0).unsqueeze(dim=0).to(device)

activation_wt_tensor=torch.zeros((1, config.num_blanks, config.hidden_size))
activation_wt_tensor[:, :config.num_blanks, : len(activation_weights[0])] = torch.tensor(activation_weights)
activation_wt_tensor[:, 0, len(activation_weights[0]): len(activation_weights[0]) + len(activation_bias)] = torch.tensor(activation_bias)
# The gate weights must live on the same device as the other forward inputs.
activation_wt_tensor = activation_wt_tensor.to(device)


position_states = torch.stack(position_states, dim=0).unsqueeze(dim=0).to(device)
forward_output = forward_.forward(hidden_states, position_states, activation_wt_tensor)

# numpy reference for both branches and their product.
true_gate_input = input_ @ activation_wts.T  + np.expand_dims( activation_bs , axis=0)
true_gate_output = gelu_act( torch.tensor( true_gate_input ) ).detach().cpu().numpy()
true_linear = input_ @ wts.T  + np.expand_dims( bs , axis=0)

true_output = true_linear * true_gate_output
output = forward_output.detach().cpu().numpy()


# Every non-blank position is compared here (no first-half restriction).
counter=0
max_difference=0.
for i in range(config.position_dim):
    if i not in blank_positions:
        max_difference=max( max_difference, np.amax(np.absolute(true_output[counter] - output[0, i, :dout])) )
        counter += 1
print ("Linear GLU forward: Max absolute difference between coordinates:", max_difference)



# GLU backward / descent check: build numpy reference gradients for the
# first-half positions, overwrite the first `dout` coordinates of
# `forward_output` with the upstream gradients, then run GLUBackward_Descent.
backward_ = GLUBackward_Descent(config, din, use_softmax=False, memory_index=memory_index)
# Keep the module on the same device as its inputs, like the other modules
# under test in this script.
backward_.to(device)



counter=0
blank_counter=0
nablaw = np.zeros((dout, din))
nablab = np.zeros(dout)
nabla_act_w = np.zeros((dout, din))
nabla_act_b = np.zeros(dout)

nablax = []
nablax_p = []
nablays = []
gelu_grads = []

for i in range(config.position_dim):
    if i not in blank_positions:
        nablay = 1./np.sqrt(1.) * np.random.normal(size=(dout,))
        if i < config.num_blanks + config.seq_length // 2:
            nablays += [nablay]

            # Linear branch: the upstream gradient is scaled by the gate output.
            gated_nablay = nablay * true_gate_output[counter]
            nablaw += np.expand_dims( gated_nablay, axis=-1 ) @ np.expand_dims( input_[counter], axis=0 )
            nablab += gated_nablay

            # Gate branch: finite-difference estimate with step 1 / scale_embeddings
            # in the direction nablay * linear_output.
            # NOTE(review): the slice [memory_index+din, memory_index+2*din) is
            # presumably where GLUForward stores the gate pre-activation -- confirm.
            gate_input = forward_output[0, i, memory_index+din: memory_index+2*din]
            # The perturbation is built on the same device as forward_output so
            # the addition below does not mix devices.
            perturbation = 1./config.scale_embeddings * torch.tensor(nablay * true_linear[counter]).to(forward_output.device)
            gelu_grad = config.scale_embeddings * ( gelu_act(gate_input + perturbation) - gelu_act(gate_input) ).detach().cpu().numpy()
            gelu_grads += [gelu_grad]

            nabla_act_w += np.expand_dims( gelu_grad, axis=-1 ) @ np.expand_dims( input_[counter], axis=0 )
            nabla_act_b += gelu_grad

            # Input gradient: contribution of both branches, and of the gate alone.
            nablax += [ gated_nablay @ wts + gelu_grad @ activation_wts ]
            nablax_p += [ gelu_grad @ activation_wts ]
        # Every non-blank position receives its upstream gradient in place.
        forward_output[0, i, :dout] = torch.tensor( nablay, dtype=forward_output.dtype ).to(forward_output.device)
        counter += 1



# Lower-triangular mask, with the first `right` positions additionally made
# visible to each other in both directions.
attention_mask = torch.ones( ( config.position_dim,  config.position_dim ) ).to(device)
attention_mask = torch.tril(attention_mask).view( (1, 1, config.position_dim,  config.position_dim)  )
right=config.num_blanks + config.seq_length // 2
for i in range(0, right ):
    attention_mask[:, :, i, i: right] = 1.

# .to(device) is a no-op when the gate weights are already on that device.
descent_update, activation_descent_update  = backward_.forward(forward_output, position_states, attention_mask, activation_wt_tensor.to(device))
descent_update = descent_update.detach().cpu().numpy()
activation_descent_update = activation_descent_update.detach().cpu().numpy()



#LinearDescent(config=config, din=din, dout=projection_dout, use_softmax=False, memory_index=memory_index)
#descent_.to(device)
#descent_update = descent_.forward(back_output_, position_states, attention_mask).detach().cpu().numpy() 






# Compare the module's input gradients with the numpy references on the
# non-blank positions below num_blanks + seq_length // 2.

# Gate-branch-only contribution.
counter = 0
max_difference = 0.
for i in range(config.position_dim):
    if i in blank_positions or i >= (config.num_blanks + config.seq_length // 2):
        continue
    diff = np.absolute(activation_descent_update[0, i, :din] - nablax_p[counter])
    max_difference = max(max_difference, np.amax(diff))
    counter += 1
print ("GLU backward: (part) Max absolute difference between coordinates:", max_difference)


# Full input gradient (linear branch + gate branch).
counter = 0
max_difference = 0.
for i in range(config.position_dim):
    if i in blank_positions or i >= (config.num_blanks + config.seq_length // 2):
        continue
    diff = np.absolute(descent_update[0, i, :din] - nablax[counter])
    max_difference = max(max_difference, np.amax(diff))
    counter += 1
print ("GLU backward: Max absolute difference between coordinates:", max_difference)




# Expected parameters after one gradient-descent step with config.inner_lr,
# for the linear branch and for the gate branch.
updated_wts = wts - config.inner_lr * nablaw
updated_b   = bias - config.inner_lr * nablab


updated_act_wts = activation_wts - config.inner_lr * nabla_act_w
updated_act_b   = activation_bias - config.inner_lr * nabla_act_b

# Linear-branch check: blank position i holds rows
# [i * dout // num_blanks, (i + 1) * dout // num_blanks) of the updated weight
# matrix, flattened; the first blank also holds the updated bias right after
# its weight chunk.
# All offsets use this section's `dout` (the original mixed in
# `projection_dout`, a leftover name from an earlier test section), matching
# the gate-weight check below.
counter=0
max_difference=0.
num_blanks=config.num_blanks
rows_per_blank = dout // num_blanks
chunk_len = din * dout // num_blanks
for i in range(num_blanks):
    pos = blank_positions[i]
    result=descent_update[0, pos, :chunk_len]
    expected_rows = updated_wts[i * rows_per_blank: (i+1) * rows_per_blank]
    max_difference=max(max_difference, np.amax(np.absolute(np.reshape(result, (-1, din)) - expected_rows)) )

    if i == 0:
        result=descent_update[0, pos, chunk_len:chunk_len + dout]
        max_difference=max( max_difference, np.amax(np.absolute(result - updated_b)) )

print ("Linear descent (linear weights): Max absolute difference between coordinates:", max_difference)


# Gate-branch check: blank position i holds a flattened chunk of
# dout // num_blanks rows of the updated gate weights; the first blank also
# holds the updated gate bias right after its weight chunk.
counter = 0
max_difference = 0.
num_blanks = config.num_blanks
for i, pos in enumerate(blank_positions[:num_blanks]):
    row_start = i * (dout // num_blanks)
    row_stop = (i + 1) * (dout // num_blanks)

    result = activation_descent_update[0, pos, :din * dout // num_blanks]
    expected_rows = updated_act_wts[row_start:row_stop]
    max_diff = np.amax(np.absolute(np.reshape(result, (-1, din)) - expected_rows))
    max_difference = max(max_difference, max_diff)

    if i != 0:
        continue
    bias_start = din * dout // num_blanks
    result = activation_descent_update[0, pos, bias_start:bias_start + dout]
    max_difference = max(max_difference, np.amax(np.absolute(result - updated_act_b)))

print ("Linear descent (gate weights): Max absolute difference between coordinates:", max_difference)

